Source: ui/GraphView.jsx

import React, { Component } from 'react';
import * as d3 from "d3";
import {
    Typography, FormControl, Select, MenuItem, Link,
    InputLabel, TextField, Fab, FormControlLabel, Checkbox
} from "@material-ui/core";

import "./GraphView.css";
import "../Home0.css";
import Graph from '../scene/Graph.js'

// Human-readable tooltip labels, keyed by the layer/config field they describe.
const fieldToLabel = {
  // Fields read directly off the layer object.
  "name": "Name",
  "class_name": "Op",

  // Fields read from the layer's config object.
  "kernel_size": "Kernel size",
  "pool_size": "Pool size",
  "filters": "Filters",
  "strides": "Strides",
  "activation": "Activation",
  "units": "Units",
  "axis": "Axis",
  "batch_input_shape": "Input shape",
  "dtype": "Dtype"
};

// Config fields listed in a node's tooltip, keyed by the layer's class_name.
// Class names that are not listed here get a tooltip with name and type only.
const classNameToConfigFields = {
  Conv2D: ['kernel_size', 'strides', 'filters'],
  Dense: ['activation', 'units'],
  Concatenate: ['axis'],
  MaxPooling2D: ['pool_size', 'strides'],
  Input: ['batch_input_shape', 'dtype'],
  // nodeColor() in this file keys input layers as 'InputLayer'; register that
  // spelling too so those layers show their shape and dtype.
  InputLayer: ['batch_input_shape', 'dtype'],
  Softmax: ['axis'],
  ReLU: []
}

/**
 * Escapes a value for safe inclusion in HTML text content.
 *
 * @param {*} value - anything; it is converted with String() first.
 * @returns {string} the text with &, <, >, " and ' replaced by entities.
 */
const escapeTooltipText = (value) => String(value)
  .replace(/&/g, '&amp;')
  .replace(/</g, '&lt;')
  .replace(/>/g, '&gt;')
  .replace(/"/g, '&quot;')
  .replace(/'/g, '&#39;');

/**
 * Builds the HTML for a layer's tooltip: a two-column flex row with labels on
 * the left and the matching values on the right.
 *
 * The first two rows are always the layer name and class name. Further rows
 * come from classNameToConfigFields[layer.class_name], read from layer.config.
 *
 * @param {Object} layer - needs `name` and `class_name`; `config` is optional.
 * @returns {string} HTML markup for the tooltip body.
 */
const getTooltipHTML = (layer) => {
  const cN = layer.class_name;
  // Tolerate layers without a config object; their values print as "undefined".
  const config = layer.config || {};

  // Own-property check: `in` would also match inherited keys such as
  // "constructor", whose value is not an iterable list of field names.
  const configFields =
    Object.prototype.hasOwnProperty.call(classNameToConfigFields, cN)
      ? classNameToConfigFields[cN]
      : [];

  const labelRows = configFields.map((cF) => {
    // Fall back to the raw field key when no display label is registered.
    const label = Object.prototype.hasOwnProperty.call(fieldToLabel, cF)
      ? fieldToLabel[cF]
      : cF;
    return escapeTooltipText(label) + ':<br/>';
  });

  // Names and config values come from the loaded model, and the result is
  // assigned as HTML, so every interpolated value is escaped.
  const valueRows = configFields.map(
    (cF) => escapeTooltipText(config[cF]) + '<br/>'
  );

  return '<div style="padding: 2px 5px; display: flex; flex-direction: row;">'
    + '<div style="padding: 2px 5px;" >'
    + 'Name: <br/>Type: <br/>'
    + labelRows.join('')
    + '</div>'
    + '<div style="padding: 2px 5px;" >'
    + escapeTooltipText(layer.name) + '<br/>' + escapeTooltipText(cN) + '<br/>'
    + valueRows.join('')
    + '</div>'
    + '</div>';
}

/**
 * Renders the model graph as a zoomable SVG inside `#graphContainer`.
 *
 * Whatever a previous call rendered is removed first, including the tooltip
 * element that a previous call appended to `<body>`, so the function can be
 * called again whenever the graph changes.
 *
 * Side effects besides the DOM output:
 *  - writes `x`, `y`, `midX`, `inY` and `outY` onto every layer object
 *    returned by `graph.getLayoutByInputDist()`;
 *  - binds a click handler on `#graph-home` that resets the zoom.
 *
 * @param {Object} props
 * @param {Object} props.graph - must provide `getLayoutByInputDist()` and
 *   `getSortedLayerList()`.
 * @param {Function} props.clickedNode - called with the layer datum when a
 *   node rectangle is clicked.
 */
const draw = (props) => {
  // Remove everything a previous call left behind. The tooltip lives on
  // <body>, not in the container, so it has to be removed explicitly or one
  // extra hidden div would pile up per redraw.
  d3.selectAll("#graphContainer > *").remove();
  d3.selectAll(".graphViewToolTip").remove();

  const graph = props.graph;
  const distList = graph.getLayoutByInputDist();
  const layerList = graph.getSortedLayerList();

  const graphMargin = { top: 40, right: 40, bottom: 40, left: 40 };
  const graphWidth = 1000 - graphMargin.left - graphMargin.right;
  const graphHeight = 1500 - graphMargin.top - graphMargin.bottom;
  const svgWidth = graphWidth + graphMargin.left + graphMargin.right;
  const svgHeight = graphHeight + graphMargin.top + graphMargin.bottom;

  // Node box size and the spacing between nodes / between rows.
  const fvWidth = 120;
  const fvHeight = fvWidth / 4;
  const layerVerticalSpace = 150;
  const fvHorizontalSpace = 50;

  // Applies the current zoom/pan transform to the group holding the graph.
  function zoomed() {
    d3.select('#graphG').attr("transform", d3.event.transform);
  }

  const zoom = d3.zoom()
    .scaleExtent([.1, 3.5])
    .extent([[0, 0], [graphWidth, graphHeight]])
    .on("zoom", zoomed);

  const graphSVG = d3.select("#graphContainer")
    .append("svg")
    .attr('viewBox', '0 0 ' + svgWidth + ' ' + svgHeight)
    .attr('width', '100%')
    .style('border-bottom', '1px solid rgba(0, 0, 0, 0.1)')
    .attr('id', 'graphSVG');

  // SVG filter definition with id "dilate"; nothing in this function
  // references it.
  graphSVG.append('filter')
    .attr('id', 'dilate')
    .append('feMorphology')
    .attr('operator', 'dilate')
    .attr('radius', 10);

  // Tooltip shown while hovering a node; starts fully transparent.
  const ttDiv = d3.select("body").append("div")
    .attr("class", "graphViewToolTip")
    .style("opacity", 0);

  // Invisible rect covering the SVG that receives the zoom/pan gestures.
  const zoomRect = graphSVG.append("rect")
    .attr("width", svgWidth)
    .attr("height", svgHeight)
    .style("fill", "none")
    .style("pointer-events", "all")
    .call(zoom);

  const graphG = graphSVG
    .append("g")
    .attr("transform", "translate(" + graphMargin.left + "," + graphMargin.top + ")")
    .attr('id', 'graphG');

  // Marks the coordinate origin (0, 0) of the graph group with a dot.
  function drawOrigin() {
    graphG.append('circle')
      .attr('r', 10)
      .attr('cx', 0)
      .attr('cy', 0);
  }

  drawOrigin();

  // Animates back to the default view: shifted right by the graph width / 2
  // and down by 50, at scale 0.2.
  function centerDag() {
    zoomRect.transition()
      .duration(750)
      .call(zoom.transform, d3.zoomIdentity.translate(graphWidth / 2, 50).scale(0.2));
  }

  centerDag();
  d3.select('#graph-home').on('click', () => {
    centerDag();
  });

  /**
   * Assigns layout coordinates to every layer in `rows`.
   * Each entry of `rows` is `{distance, layers}`; `distance` selects the row.
   */
  const computeNodeCoordinates = (rows) => {
    rows.forEach((row) => {
      const { distance, layers } = row;
      // Total width of this row: all boxes plus the gaps between them.
      const rowWidth = layers.length * fvWidth
        + (layers.length - 1) * fvHorizontalSpace;

      layers.forEach((layer, i) => {
        const inbound = layer.inboundNodes;
        if (inbound.length === 1 && inbound[0].outboundNodes.length === 1) {
          // Part of a straight chain: reuse the x of the single predecessor.
          // NOTE(review): relies on the predecessor already having its x
          // assigned, i.e. on `rows` being ordered by distance — confirm.
          layer.x = inbound[0].x;
        } else {
          // Otherwise spread the row's layers evenly around x = 0.
          layer.x = (fvWidth + fvHorizontalSpace) * i - rowWidth / 2;
        }
        layer.y = distance * layerVerticalSpace;
        layer.midX = layer.x + fvWidth / 2;
        layer.inY = layer.y;
        layer.outY = layer.y + fvHeight;
      });
    });
  };

  // The text positioning helpers measure the rendered text via getBBox(), so
  // they must be plain functions: d3 passes the DOM element as `this`.

  // Horizontally centres a text element on its node box.
  const nodeTextX = function (d) {
    const bb = this.getBBox();
    return d.x + fvWidth / 2 - bb.width / 2;
  };

  // Baseline for the class-name text inside the node box.
  const nodeTextY = function (d) {
    const bb = this.getBBox();
    return d.y + fvHeight / 2 + bb.height / 2 - 3;
  };

  // Baseline for the layer-name text below the node box.
  const nodeNameY = function (d) {
    const bb = this.getBBox();
    return d.y + fvHeight + bb.height + 5;
  };

  // Fill colour per layer class name.
  const classColors = {
    'Conv2D': 'lightcoral',
    'MaxPool2D': 'honeydew',
    'AvgPool2D': 'honeydew',
    'MaxPooling2D': 'honeydew',
    'AveragePooling2D': 'honeydew',
    'GlobalAveragePooling2D': 'honeydew',
    'ReLU': 'lightblue',
    'Concat': 'lightgray',
    'Concatenate': 'lightgray',
    'InputLayer': 'white',
    'Flatten': 'lightgray',
    'Reshape': 'lightgray',
    'Softmax': 'lightblue',
    'Dense': 'lightyellow',
  };

  // Fill for a node rect: the selection highlight wins over the class colour,
  // and unknown classes are light gray. Reads the element through `this`.
  function nodeColor(d) {
    if (d3.select(this).classed('selected')) {
      return 'greenyellow';
    }
    return classColors[d.class_name] || 'lightgray';
  }

  computeNodeCoordinates(distList);

  // One <g> per layer, holding its incoming edges, its box and its labels.
  const layerNodes = graphG
    .selectAll("g")
    .data(layerList)
    .enter()
    .append("g")
    .attr('id', (d) => 'layerNode-' + d.name)
    .attr('class', 'layerNode');

  // Draws one cubic Bezier path from each inbound layer's bottom centre to
  // this layer's top centre. Called through selection.each, so `this` is the
  // layer's <g> element.
  const drawInputConnections = function (dThis) {
    d3.select(this).selectAll('path')
      .data(dThis.inboundNodes)
      .enter()
      .append('path')
      .attr('d', (dIn) => {
        const inX = dIn.midX;
        const inY = dIn.outY;
        const outX = dThis.midX;
        const outY = dThis.inY;
        // Both control points sit half a row above the target.
        const ctrlY = outY - layerVerticalSpace / 2;

        return "M" + inX + "," + inY
          + "C" + inX + " " + ctrlY + ","
          + outX + " " + ctrlY + ","
          + outX + "," + outY;
      })
      .style('stroke-width', 2)
      .style('stroke', 'darkgray')
      .style('fill', 'transparent');
  };

  layerNodes.each(drawInputConnections);

  const layerNodeRects = layerNodes.append('rect')
    .attr("x", (d) => d.x)
    .attr("y", (d) => d.y)
    .attr("width", fvWidth)
    .attr("height", fvHeight)
    .style('fill', nodeColor)
    .style('stroke', 'darkgray')
    .style('stroke-width', 1)
    .attr('rx', 10)
    .attr('id', (d) => 'layerNodeRect-' + d.name)
    .attr('class', 'layerNodeRect')
    .classed('selected', false)
    .on('mouseover', function (d) {
      // Reveal the layer-name label and the tooltip, and thicken the border
      // unless the node already carries the thicker selection border.
      d3.select(this.parentNode).select('.layerNodeName')
        .style('visibility', 'visible');

      ttDiv.transition()
        .duration(200)
        .style("opacity", .9);
      ttDiv.html(() => getTooltipHTML(d))
        .style("left", (d3.event.pageX + 20) + "px")
        .style("top", (d3.event.pageY - 28) + "px");

      const currEl = d3.select(this);
      currEl.style('stroke-width', currEl.classed("selected") ? 5 : 2);
    })
    .on('mouseout', function () {
      d3.select(this.parentNode).select('.layerNodeName')
        .style('visibility', 'hidden');

      ttDiv.transition()
        .duration(500)
        .style("opacity", 0);

      const currEl = d3.select(this);
      currEl.style('stroke-width', currEl.classed("selected") ? 5 : 1);
    })
    .on('click', function (d) {
      props.clickedNode(d);

      // Clear the previous selection on all nodes. The 'selected' class has
      // to change before the fill is recomputed, because nodeColor reads it.
      layerNodeRects
        .style('stroke-width', 1)
        .style('stroke', 'darkgray')
        .classed('selected', false)
        .style('fill', nodeColor);

      d3.select(this)
        .style('stroke-width', 5)
        .style('stroke', 'black')
        .classed('selected', true)
        .style('fill', nodeColor);
    });

  // Class name, centred inside the box.
  layerNodes.append('text')
    .text((d) => d.class_name)
    .style('font-size', fvHeight / 2)
    .attr('class', 'layerNodeText')
    .attr('x', nodeTextX)
    .attr('y', nodeTextY)
    .style('fill', 'black')
    .attr("pointer-events", "none");

  // Layer name below the box; hidden until the box is hovered.
  layerNodes.append('text')
    .text((d) => d.name)
    .style('font-size', fvHeight / 2)
    .attr('class', 'layerNodeText layerNodeName')
    .attr('x', nodeTextX)
    .attr('y', nodeNameY)
    .style('visibility', 'hidden')
    .style('fill', 'black')
    .attr("pointer-events", "none");

}

/**
 * Component containing the graph visualization of the model.
 */
class GraphView extends Component {

  componentDidMount(){
    if(this.props.graph){
      draw(this.props);
    }
  }

  componentDidUpdate(oldProps){
    if(this.props.graph && oldProps.graph !== this.props.graph){
      draw(this.props);
    }
  }

  x(data, index){
    console.log(data, index);
  }

  render(){
    return (
    <div  style={{width: "100%", minHeight:"300px",
    height: "100%", overflow: "hidden"}} id="graphView">
      <div style={{width:"100%", height:"100%"}}
      id="graphContainer"></div>
    </div>);
  }

}
export default  GraphView;